/*
 * Copyright © 2021 Dowsure
 * https://www.dowsure.com/
 *
 * All rights reserved.
 */

package com.dowsure.apisaas.util.smalgorithm;

import com.dowsure.apisaas.util.StrUtil;
import org.bouncycastle.crypto.BlockCipher;
import org.bouncycastle.crypto.engines.SM4Engine;
import org.bouncycastle.crypto.params.KeyParameter;

import java.security.SecureRandom;

/**
 * sm4加解密工具类
 * <p>因为数据加解密都是对字节数据加解密，因此需要注意加密前和解密后使用的字符集保持一致
 * <p>若无特殊说明，接口接收的都是原始的二进制数据，被hex或者base64编码的数据，务必解码之后再传进来
 * @author Dowsure
 */
public class SM4Util {

    private static final int SM4_ENCRYPT = 1;
    private static final int SM4_DECRYPT = 0;
    public static final int SM4_PKCS8PADDING = 1;
    public static final int SM4_NOPADDING = 0;
    public static final int SM4_KEY_128 = 128;

    public static SM4KeyPair generateSecretKey(int keySize) {
        SM4KeyPair pair = new SM4KeyPair();
        String secretKey = StrUtil.randomStr(keySize);
        String iv = StrUtil.randomStr(keySize);
        pair.setSecretKey(secretKey);
        pair.setIv(iv);
        pair.setSecretKeyByte(secretKey.getBytes());
        pair.setIvByte(iv.getBytes());
        return pair;
    }

    /**
     * 生成sm4密钥，长度使用
     * @param keySize 密钥位数（通过SM4Util的常量获取长度值）
     * @return sm4密钥
     */
    public static byte[] generateKey(int keySize) {
        byte[] key = new byte[keySize / 8];
        SecureRandom sr = new SecureRandom();
        sr.nextBytes(key);

        return key;
    }

    /**
     * sm4 ecb模式加密数据，数据长度非16倍数，则使用默认PKCS8PADDING方式填充
     * @param data 待加密的数据
     * @param key  sm4密钥
     * @return 密文数据
     */
    public static byte[] encryptECB(byte[] data, byte[] key) {
        return encryptECB(data, key, SM4_PKCS8PADDING);
    }

    /**
     * sm4 ecb模式解密数据，使用默认PKCS8PADDING方式去除填充
     * @param cipher 密文数据
     * @param key    sm4密钥
     * @return 明文字节数据
     */
    public static byte[] decryptECB(byte[] cipher, byte[] key) {
        return decryptECB(cipher, key, SM4_PKCS8PADDING);
    }

    /**
     * sm4 CBC模式加密数据，数据长度非16倍数，则使用默认PKCS8PADDING方式填充
     * @param data 待加密数据
     * @param key  sm4密钥
     * @param iv   向量
     * @return 密文数据
     */
    public static byte[] encryptCBC(byte[] data, byte[] key, byte[] iv) {
        return encryptCBC(data, key, iv, SM4_PKCS8PADDING);
    }

    /**
     * sm4 cbc模式解密数据，使用默认PKCS8PADDING方式去除填充
     * @param cipher sm4密文数据
     * @param key    sm4密钥
     * @param iv     向量
     * @return 明文字节数据
     */
    public static byte[] decryptCBC(byte[] cipher, byte[] key, byte[] iv) {
        return decryptCBC(cipher, key, iv, SM4_PKCS8PADDING);
    }

    /**
     * sm4 ecb模式加密数据
     * @param data        待加密数据
     * @param key         sm4密钥
     * @param paddingMode 填充模式，具体支持请看类的常量字段,若使用不支持的模式则会默认无填充
     * @return 返回密文数据
     */
    public static byte[] encryptECB(byte[] data, byte[] key, int paddingMode) {
        BlockCipher engine = new SM4Engine();
        engine.init(true, new KeyParameter(key));
        if (paddingMode == SM4_PKCS8PADDING) {
            data = padding(data, SM4_ENCRYPT);
        } else {
            data = data.clone();
        }
        int length = data.length;
        for (int i = 0; length > 0; length -= 16, i += 16) {
            engine.processBlock(data, i, data, i);
        }
        return data;
    }

    /**
     * sm4 ecb模式解密数据
     * @param cipher      密文数据
     * @param key         sm4密钥
     * @param paddingMode 填充模式，具体支持请看类的常量字段,若使用不支持的模式则会默认无填充
     * @return 返回明文字节数据
     */
    public static byte[] decryptECB(byte[] cipher, byte[] key, int paddingMode) {
        BlockCipher engine = new SM4Engine();
        engine.init(false, new KeyParameter(key));
        int length = cipher.length;
        byte[] tmp = new byte[cipher.length];
        for (int i = 0; length > 0; length -= 16, i += 16) {
            engine.processBlock(cipher, i, tmp, i);
        }
        byte[] plain = null;
        if (paddingMode == SM4_PKCS8PADDING) {
            plain = padding(tmp, SM4_DECRYPT);
        } else {
            plain = tmp;
        }
        return plain;
    }

    /**
     * CBC模式加密数据
     * @param data        待加密数据
     * @param key         密钥
     * @param iv          向量
     * @param paddingMode 填充模式，具体支持请看类的常量字段,若使用不支持的模式则会默认无填充
     * @return 返回密文值
     */
    public static byte[] encryptCBC(byte[] data, byte[] key, byte[] iv, int paddingMode) {
        BlockCipher engine = new SM4Engine();
        engine.init(true, new KeyParameter(key));
        if (paddingMode == SM4_PKCS8PADDING) {
            data = padding(data, SM4_ENCRYPT);
        } else {
            data = data.clone();
        }
        int length = data.length;
        iv = iv.clone();
        for (int i = 0; length > 0; length -= 16, i += 16) {

            for (int j = 0; j < 16; j++) {
                data[i + j] = ((byte) (data[i + j] ^ iv[j]));
            }
            engine.processBlock(data, i, data, i);
            System.arraycopy(data, i, iv, 0, 16);
        }
        return data;
    }

    /**
     * CBC模式解密数据
     * @param cipher    密文数据
     * @param key       密钥
     * @param iv        向量
     * @param paddingMode 填充模式，具体支持请看类的常量字段,若使用不支持的模式则会默认无填充
     * @return 返回明文字节数据
     */
    public static byte[] decryptCBC(byte[] cipher, byte[] key, byte[] iv, int paddingMode) {
        BlockCipher engine = new SM4Engine();
        engine.init(false, new KeyParameter(key));
        int length = cipher.length;
        byte[] plain = new byte[cipher.length];
        iv = iv.clone();
        for (int i = 0; length > 0; length -= 16, i += 16) {

            engine.processBlock(cipher, i, plain, i);
            for (int j = 0; j < 16; j++) {
                plain[j + i] = ((byte) (plain[i + j] ^ iv[j]));
            }
            System.arraycopy(cipher, i, iv, 0, 16);
        }

        byte[] res = null;
        if (paddingMode == SM4_PKCS8PADDING) {
            res = padding(plain, SM4_DECRYPT);
        } else {
            res = plain;
        }
        return res;
    }

    /**
     * PKCS8PADDING标准填充
     * @param input 输入数据
     * @param mode  填充或去除填充
     * @return
     */
    private static byte[] padding(byte[] input, int mode) {
        if (input == null) {
            return null;
        }

        byte[] ret = null;
        if (mode == SM4_ENCRYPT) {
            int p = 16 - input.length % 16;
            ret = new byte[input.length + p];
            System.arraycopy(input, 0, ret, 0, input.length);
            for (int i = 0; i < p; i++) {
                ret[input.length + i] = (byte) p;
            }
        } else {
            int p = input[input.length - 1];
            ret = new byte[input.length - p];
            System.arraycopy(input, 0, ret, 0, input.length - p);
        }
        return ret;
    }
}
